from typing import *
from torch_geometric.typing import *

import torch
from torch import Tensor
from torch_sparse import SparseTensor


# Template for a generated sequential container module: the class name, the
# ``forward`` signature and one statement per sub-module call are filled in
# when this template is rendered.
class {{cls_name}}(torch.nn.Module):
    # Re-initialise every direct child module that exposes a
    # ``reset_parameters`` method; children without one are skipped.
    def reset_parameters(self):
        for module in self.children():
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

    # One assignment is emitted per call, in order: the call at zero-based
    # position i invokes ``self.module_i`` with that call's input names and
    # binds the result to that call's output names (comma-separated).
    # NOTE(review): the ``module_i`` attributes are not set anywhere in this
    # template; presumably whoever instantiates the rendered class registers
    # them -- confirm against the caller.
    def forward(self, {{ input_args|join(', ') }}):
        """"""{% for call in calls %}
        {% for output in call[2] %}{{output}}{{ ", " if not loop.last }}{% endfor %} = self.module_{{loop.index - 1}}({% for input in call[1] %}{{input}}{{ ", " if not loop.last }}{% endfor %}){% endfor %}
        # The output names of the last call form the result of the module.
        # NOTE(review): this line indexes the last call, so rendering assumes
        # at least one call was supplied -- confirm the caller guarantees it.
        return {% for output in calls[-1][2] %}{{output}}{{ ", " if not loop.last }}{% endfor %}

    # Multi-line summary: a fixed 'Sequential(' header (not the generated
    # class name), then one indented '(i): <str of module_i>' line per
    # sub-module, then ')'. The number of sub-modules is baked in at render
    # time from the number of calls.
    def __repr__(self) -> str:
        return 'Sequential(\n{}\n)'.format('\n'.join(
            [f'  ({i}): ' + str(getattr(self, f'module_{i}')) for i in range({{calls|length}})]))
